add mamba causal-conv1d-update kernel #48
base: main
Conversation
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
You are returning `o`, but it is not listed here in your function signature?
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
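The skip behavior described above can be sketched in plain Python. This is not the Triton kernel itself, just a minimal illustration of the assumed semantics: any batch entry whose cache index equals `pad_slot_id` is left untouched. The sentinel value and the `update_conv_states` helper name are assumptions for the example.

```python
PAD_SLOT_ID = -1  # assumed sentinel value; the real constant comes from the library

def update_conv_states(conv_state, new_vals, conv_state_indices,
                       pad_slot_id=PAD_SLOT_ID):
    # Pure-Python sketch of the kernel's skip behavior: entries whose cache
    # index equals pad_slot_id are ignored, so their cached state is never read
    # or written.
    for batch_idx, slot in enumerate(conv_state_indices):
        if slot == pad_slot_id:
            continue  # padded entry: no state update
        conv_state[slot] = new_vals[batch_idx]
    return conv_state

# Mirrors the example from the docstring: entries 0 and 3 are skipped.
states = {1: "old1", 20: "old20"}
cache_indices = [PAD_SLOT_ID, 1, 20, PAD_SLOT_ID]
update_conv_states(states, ["a", "b", "c", "d"], cache_indices)
# states is now {1: "b", 20: "c"}; slots at indices 0 and 3 were untouched
```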
Also inconsistent: the docstring names the output `out`, but the code uses `o`.
    conv_state_indices=conv_state_indices,
    pad_slot_id=pad_slot_id,
)
return o
Nit, but I'm not a fan of using `o` by itself... `out` or `output` etc. makes it clearer IMO.
Hi @thoangtrvn - thanks for the update and sorry for the delay!
Thanks @lessw2020, I'll update based on your feedback.
This is related to the prior PR: #47
This adds the second Triton kernel (the decode-stage kernel) used by Mamba-based models for inference.
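For context, the decode-stage update of a causal conv1d can be sketched as a single-channel, single-token reference in plain Python. This is a simplification of what the Triton kernel computes in parallel across batch and channels; the function name and exact state layout are assumptions, not the PR's implementation.

```python
def causal_conv1d_update_ref(x, conv_state, weight, bias=0.0):
    # Assumed decode-step semantics: conv_state holds the last (width - 1)
    # inputs for one channel. The new input x is appended, a causal dot
    # product with the filter gives the output for this token, and the
    # state window then slides forward by one.
    window = conv_state + [x]            # length == len(weight)
    out = sum(w * v for w, v in zip(weight, window)) + bias
    conv_state[:] = window[1:]           # shifted state for the next token
    return out

state = [0.0, 0.0, 0.0]                  # width - 1 zeros (filter width 4)
w = [0.1, 0.2, 0.3, 0.4]
y = causal_conv1d_update_ref(1.0, state, w)
# y == 0.4 (only the newest input is nonzero); state becomes [0.0, 0.0, 1.0]
```

Each decode step touches only a width-sized window per channel, which is why a dedicated update kernel is worthwhile: it avoids re-running the full-sequence convolution from the prefill stage.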